package org.feng.interceptor;

import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.feng.util.TimeUtil;
import org.springframework.util.CollectionUtils;

import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

/**
 * mybatis拦截器拦截处理查询、更新的方法，mybatis-plus拦截器见：{@link MybatisPlusInterceptor}
 *
 * @version v1.0
 * @author: fengjinsong
 * @date: 2023年08月25日 23时23分
 */

@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
})
@Slf4j
public class MybatisPrintSqlInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // 获取语句映射对象
        Object[] invocationArgs = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) invocationArgs[0];

        // 获取参数（条件）
        Object paramObject = null;
        // 2个以上的入参，也就是有额外的查询或更新条件
        if (invocationArgs.length > 1) {
            paramObject = invocationArgs[1];
        }

        BoundSql boundSql = mappedStatement.getBoundSql(paramObject);
        Configuration configuration = mappedStatement.getConfiguration();
        String mappedStatementId = mappedStatement.getId();
        // 开始执行时间
        long start = System.currentTimeMillis();
        // 执行方法
        Object returnValue = invocation.proceed();
        // 执行耗时
        long executeTime = System.currentTimeMillis() - start;
        // 拼接sql，参数注入
        String sql = concatSql(configuration, boundSql);
        // 打印sql
        logs(executeTime, sql, mappedStatementId);
        return returnValue;
    }


    private String concatSql(Configuration configuration, BoundSql boundSql) {
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        //替换空格、换行、tab缩进等
        String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
        if (!CollectionUtils.isEmpty(parameterMappings) && Objects.nonNull(parameterObject)) {
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                sql = sql.replaceFirst("\\?", getParameterValue(parameterObject));
            } else {
                MetaObject metaObject = configuration.newMetaObject(parameterObject);
                for (ParameterMapping parameterMapping : parameterMappings) {
                    String propertyName = parameterMapping.getProperty();
                    if (metaObject.hasGetter(propertyName)) {
                        Object obj = metaObject.getValue(propertyName);
                        sql = sql.replaceFirst("\\?", getParameterValue(obj));
                    } else if (boundSql.hasAdditionalParameter(propertyName)) {
                        Object obj = boundSql.getAdditionalParameter(propertyName);
                        sql = sql.replaceFirst("\\?", getParameterValue(obj));
                    }
                }
            }
        }
        return sql;
    }

    private String getParameterValue(Object obj) {
        String value;
        if (obj instanceof String) {
            value = "'" + obj + "'";
        } else if (obj instanceof Date) {
            value = "'" + TimeUtil.defaultFormat(((Date) obj).toInstant()) + "'";
        } else if (obj instanceof LocalDateTime) {
            value = "'" + TimeUtil.defaultFormat((LocalDateTime) obj) + "'";
        } else if (obj instanceof LocalDate) {
            value = "'" + TimeUtil.defaultFormat((LocalDate) obj) + "'";
        } else {
            if (obj != null) {
                value = obj.toString();
            } else {
                value = "";
            }
        }
        return value.replace("$", "\\$");
    }

    private void logs(long time, String sql, String sqlId) {
        log.info("\r\n执行SQL：{} \r\n执行耗时：{}ms, 执行方法:{}", sql, time, sqlId);
    }

    @Override
    public Object plugin(Object target) {
        // 如果是Executor（执行增删改查操作），则拦截下来
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }
}
